"""
nuScenes 多视角图像提取及可视化
环境要求：Python 3.6+，需要安装以下包：
  pip install nuscenes-devkit matplotlib pillow
"""

import os
import argparse
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from nuscenes import NuScenes

# 配置参数
DATASET_ROOT = './mini'  # 修改为你的数据集路径
SCENE_INDEX = 0                        # 要提取的场景索引
SAVE_IMAGES = True                     # 是否保存图像到本地
OUTPUT_DIR = './output'                # 输出目录

def main():
    # 创建输出目录
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # 初始化数据集
    if not os.path.exists(DATASET_ROOT):
        raise FileNotFoundError(f"数据集路径 {DATASET_ROOT} 不存在")

    nusc = NuScenes(version='v1.0-mini', dataroot=DATASET_ROOT, verbose=True)

    # 验证场景索引有效性
    if SCENE_INDEX >= len(nusc.scene):
        raise ValueError(f"场景索引 {SCENE_INDEX} 超出范围（总场景数：{len(nusc.scene)}）")

    # 获取指定场景
    scene = nusc.scene[SCENE_INDEX]
 

    # 获取第一个样本（起始时间点）
    sample_token = scene['first_sample_token']
    sample = nusc.get('sample', sample_token)

    # 定义摄像头顺序（环视布局）
    camera_order = [
        'CAM_FRONT_LEFT',   # 前左
        'CAM_FRONT',        # 前
        'CAM_FRONT_RIGHT',  # 前右
        'CAM_BACK_RIGHT',   # 后右
        'CAM_BACK',         # 后
        'CAM_BACK_LEFT'     # 后左
    ]

    # 收集图像数据
    images = {}
    timestamps = []
    for cam_name in camera_order:
        # 获取摄像头数据记录
        cam_token = sample['data'][cam_name]
        cam_data = nusc.get('sample_data', cam_token)
        
        # 加载图像
        img_path = os.path.join(nusc.dataroot, cam_data['filename'])
        img = Image.open(img_path)
        images[cam_name] = img
        timestamps.append(cam_data['timestamp'])

        # 保存原始图像
        if SAVE_IMAGES:
            output_path = os.path.join(OUTPUT_DIR, f"{cam_name}.jpg")
            img.save(output_path)
            print(f"保存 {cam_name} 图像到：{output_path}")

    # 验证时间同步性
    time_diff = (max(timestamps) - min(timestamps)) / 1e3  # 转换为毫秒
    print(f"\n时间同步检查：最大时间差 {time_diff:.2f} 毫秒")
    if time_diff > 50:
        print("警告：不同摄像头时间差较大，可能影响数据一致性")

    # 可视化设置
    plt.figure(figsize=(16, 9))
    plt.suptitle(f"场景 {scene['name']} - 时间戳 {min(timestamps)}", fontsize=16)

    # 绘制单独视角
    for idx, (cam_name, img) in enumerate(images.items()):
        ax = plt.subplot(2, 3, idx+1)
        ax.imshow(img)
        ax.set_title(f"{cam_name}\n{img.size[0]}x{img.size[1]}")
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, 'multiview.jpg'))
    print(f"\n环视图像已保存至：{os.path.join(OUTPUT_DIR, 'multiview.jpg')}")

    # 创建拼接环视图
    print("\n生成全景拼接视图...")
    target_height = 300  # 统一图像高度
    
    resized_images = []
    for cam_name in camera_order:
        img = images[cam_name]
        ratio = target_height / img.height
        resized_img = img.resize((int(img.width * ratio), target_height))
        resized_images.append(np.array(resized_img))

    # 分两行拼接（前三后三）
    top_row = np.hstack(resized_images[:3])
    bottom_row = np.hstack(resized_images[3:][::-1])  # 反转顺序保持环视连续性
    
    panorama = np.vstack([top_row, bottom_row])

    # 保存全景图
    panorama_path = os.path.join(OUTPUT_DIR, 'panorama.jpg')
    Image.fromarray(panorama).save(panorama_path)
    print(f"全景拼接图已保存至：{panorama_path}")

    # 显示结果（如果环境支持）
    plt.figure(figsize=(20, 8))
    plt.imshow(panorama)
    plt.axis('off')
    plt.title('环视全景拼接视图')
    plt.show()

if __name__ == '__main__':
    main()